Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Lowering for FlexAttention Backwards #125515

Conversation

drisspg
Copy link
Contributor

@drisspg drisspg commented May 3, 2024

Summary

What does this PR do?

It enables Inductor to actually generate the fused flex attention kernel for the backwards

I did some other things along the way:

  • Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel.
  • The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization.
  • I didnt correctly register the decomp table + IndexMode when I landed: Add backwards support to FlexAttention #123902, this remedies that.
  • The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention.
  • This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk'
  • I also needed to update the TritonTemplateKernel to actually accept multiple subgraphs (modifications)
  • I updated the benchmark to also profile bwds performance

Benchmark Numbers:

The current implementation is not parallelizing over ctx length in the bwd
FWD Speedups

Type Speedup shape score_mod dtype
Average 0.991
Max 1.182 (16, 16, 4096, 64) noop torch.bfloat16
Min 0.796 (2, 16, 512, 256) head_bias torch.bfloat16

BWD Speedups

Type Speedup shape score_mod dtype
Average 0.291
Max 0.652 (8, 16, 512, 64) head_bias torch.bfloat16
Min 0.073 (2, 16, 4096, 128) head_bias torch.bfloat16
Full Data
shape score_mod dtype fwd_eager_time fwd_compiled_time bwd_eager_time bwd_compiled_time fwd_speedup bwd_speedup
(2, 16, 512, 64) noop torch.bfloat16 19.936 19.092 57.851 193.564 1.044 0.299
(2, 16, 512, 64) causal_mask torch.bfloat16 19.955 19.497 57.662 206.278 1.024 0.280
(2, 16, 512, 64) relative_bias torch.bfloat16 19.455 21.297 57.674 195.219 0.913 0.295
(2, 16, 512, 64) head_bias torch.bfloat16 19.958 21.289 57.674 193.859 0.938 0.298
(2, 16, 512, 128) noop torch.bfloat16 28.157 28.615 82.831 454.211 0.984 0.182
(2, 16, 512, 128) causal_mask torch.bfloat16 28.154 28.444 83.091 432.083 0.990 0.192
(2, 16, 512, 128) relative_bias torch.bfloat16 28.722 27.897 83.175 446.789 1.030 0.186
(2, 16, 512, 128) head_bias torch.bfloat16 28.299 27.673 83.052 459.179 1.023 0.181
(2, 16, 512, 256) noop torch.bfloat16 41.167 50.504 175.019 1083.545 0.815 0.162
(2, 16, 512, 256) causal_mask torch.bfloat16 41.656 51.933 175.078 1171.176 0.802 0.149
(2, 16, 512, 256) relative_bias torch.bfloat16 41.697 50.722 175.159 1097.312 0.822 0.160
(2, 16, 512, 256) head_bias torch.bfloat16 41.690 52.387 175.184 1097.336 0.796 0.160
(2, 16, 1024, 64) noop torch.bfloat16 39.232 37.454 127.847 612.430 1.047 0.209
(2, 16, 1024, 64) causal_mask torch.bfloat16 39.930 39.599 127.755 665.359 1.008 0.192
(2, 16, 1024, 64) relative_bias torch.bfloat16 39.417 41.304 127.902 614.990 0.954 0.208
(2, 16, 1024, 64) head_bias torch.bfloat16 39.965 42.034 127.953 613.273 0.951 0.209
(2, 16, 1024, 128) noop torch.bfloat16 63.964 71.024 226.510 1637.669 0.901 0.138
(2, 16, 1024, 128) causal_mask torch.bfloat16 63.843 72.451 226.750 1558.949 0.881 0.145
(2, 16, 1024, 128) relative_bias torch.bfloat16 64.301 70.487 226.651 1610.063 0.912 0.141
(2, 16, 1024, 128) head_bias torch.bfloat16 64.033 71.394 226.676 1668.511 0.897 0.136
(2, 16, 1024, 256) noop torch.bfloat16 129.348 141.390 507.337 4405.175 0.915 0.115
(2, 16, 1024, 256) causal_mask torch.bfloat16 129.538 145.680 507.178 4768.874 0.889 0.106
(2, 16, 1024, 256) relative_bias torch.bfloat16 129.438 142.782 507.004 4401.002 0.907 0.115
(2, 16, 1024, 256) head_bias torch.bfloat16 129.058 146.242 507.547 4434.251 0.883 0.114
(2, 16, 4096, 64) noop torch.bfloat16 481.606 409.120 1440.890 14147.269 1.177 0.102
(2, 16, 4096, 64) causal_mask torch.bfloat16 480.227 438.847 1434.419 14973.386 1.094 0.096
(2, 16, 4096, 64) relative_bias torch.bfloat16 480.831 458.104 1432.935 14193.253 1.050 0.101
(2, 16, 4096, 64) head_bias torch.bfloat16 480.749 452.497 1437.040 14084.869 1.062 0.102
(2, 16, 4096, 128) noop torch.bfloat16 872.534 848.275 2600.895 35156.849 1.029 0.074
(2, 16, 4096, 128) causal_mask torch.bfloat16 872.647 868.279 2587.581 31919.531 1.005 0.081
(2, 16, 4096, 128) relative_bias torch.bfloat16 871.484 827.644 2593.989 34805.634 1.053 0.075
(2, 16, 4096, 128) head_bias torch.bfloat16 871.422 856.437 2602.482 35708.591 1.017 0.073
(2, 16, 4096, 256) noop torch.bfloat16 1904.497 1758.183 6122.416 66754.593 1.083 0.092
(2, 16, 4096, 256) causal_mask torch.bfloat16 1911.174 1762.821 6113.207 72759.392 1.084 0.084
(2, 16, 4096, 256) relative_bias torch.bfloat16 1911.254 1727.108 6123.530 66577.988 1.107 0.092
(2, 16, 4096, 256) head_bias torch.bfloat16 1916.977 1801.804 6118.158 67359.680 1.064 0.091
(8, 16, 512, 64) noop torch.bfloat16 44.984 43.974 170.276 262.259 1.023 0.649
(8, 16, 512, 64) causal_mask torch.bfloat16 45.001 46.265 170.509 274.893 0.973 0.620
(8, 16, 512, 64) relative_bias torch.bfloat16 45.466 48.211 170.606 262.759 0.943 0.649
(8, 16, 512, 64) head_bias torch.bfloat16 45.481 48.435 170.267 261.265 0.939 0.652
(8, 16, 512, 128) noop torch.bfloat16 72.565 74.736 313.220 773.126 0.971 0.405
(8, 16, 512, 128) causal_mask torch.bfloat16 72.015 75.755 313.311 775.513 0.951 0.404
(8, 16, 512, 128) relative_bias torch.bfloat16 72.105 74.189 313.806 769.238 0.972 0.408
(8, 16, 512, 128) head_bias torch.bfloat16 72.005 74.364 313.509 775.237 0.968 0.404
(8, 16, 512, 256) noop torch.bfloat16 138.656 165.453 663.707 2672.067 0.838 0.248
(8, 16, 512, 256) causal_mask torch.bfloat16 139.096 172.613 663.593 2926.538 0.806 0.227
(8, 16, 512, 256) relative_bias torch.bfloat16 139.500 168.417 663.938 2658.629 0.828 0.250
(8, 16, 512, 256) head_bias torch.bfloat16 139.776 173.549 662.920 2667.266 0.805 0.249
(8, 16, 1024, 64) noop torch.bfloat16 134.883 125.004 484.706 1195.254 1.079 0.406
(8, 16, 1024, 64) causal_mask torch.bfloat16 134.297 132.875 485.420 1234.953 1.011 0.393
(8, 16, 1024, 64) relative_bias torch.bfloat16 134.839 139.231 485.470 1198.556 0.968 0.405
(8, 16, 1024, 64) head_bias torch.bfloat16 133.822 136.449 485.608 1189.198 0.981 0.408
(8, 16, 1024, 128) noop torch.bfloat16 235.470 234.765 886.094 2662.944 1.003 0.333
(8, 16, 1024, 128) causal_mask torch.bfloat16 236.305 241.382 886.293 2646.984 0.979 0.335
(8, 16, 1024, 128) relative_bias torch.bfloat16 236.414 233.980 885.250 2642.178 1.010 0.335
(8, 16, 1024, 128) head_bias torch.bfloat16 237.176 239.040 885.754 2665.242 0.992 0.332
(8, 16, 1024, 256) noop torch.bfloat16 504.445 517.855 1978.956 9592.906 0.974 0.206
(8, 16, 1024, 256) causal_mask torch.bfloat16 502.428 536.002 1978.611 10607.342 0.937 0.187
(8, 16, 1024, 256) relative_bias torch.bfloat16 503.396 523.960 1977.993 9539.284 0.961 0.207
(8, 16, 1024, 256) head_bias torch.bfloat16 503.818 536.014 1980.131 9576.262 0.940 0.207
(8, 16, 4096, 64) noop torch.bfloat16 1970.139 1674.930 5750.940 16724.134 1.176 0.344
(8, 16, 4096, 64) causal_mask torch.bfloat16 1959.036 1775.056 5780.512 17390.350 1.104 0.332
(8, 16, 4096, 64) relative_bias torch.bfloat16 1947.198 1773.869 5780.643 16779.699 1.098 0.345
(8, 16, 4096, 64) head_bias torch.bfloat16 1963.935 1829.502 5780.018 16703.259 1.073 0.346
(8, 16, 4096, 128) noop torch.bfloat16 3582.711 3362.623 10436.069 36415.565 1.065 0.287
(8, 16, 4096, 128) causal_mask torch.bfloat16 3581.504 3499.472 10346.869 36164.959 1.023 0.286
(8, 16, 4096, 128) relative_bias torch.bfloat16 3589.779 3337.849 10529.621 36261.696 1.075 0.290
(8, 16, 4096, 128) head_bias torch.bfloat16 3602.265 3436.444 10458.660 36507.790 1.048 0.286
(8, 16, 4096, 256) noop torch.bfloat16 7695.923 7126.275 24643.009 140949.081 1.080 0.175
(8, 16, 4096, 256) causal_mask torch.bfloat16 7679.939 7186.252 24538.105 157156.067 1.069 0.156
(8, 16, 4096, 256) relative_bias torch.bfloat16 7681.374 6994.832 24549.713 140077.179 1.098 0.175
(8, 16, 4096, 256) head_bias torch.bfloat16 7679.822 7212.278 24627.823 140675.003 1.065 0.175
(16, 16, 512, 64) noop torch.bfloat16 80.126 78.291 333.719 541.165 1.023 0.617
(16, 16, 512, 64) causal_mask torch.bfloat16 80.065 81.696 333.779 551.113 0.980 0.606
(16, 16, 512, 64) relative_bias torch.bfloat16 80.138 86.715 333.364 542.118 0.924 0.615
(16, 16, 512, 64) head_bias torch.bfloat16 80.415 85.204 333.294 536.840 0.944 0.621
(16, 16, 512, 128) noop torch.bfloat16 134.964 138.025 607.093 1333.102 0.978 0.455
(16, 16, 512, 128) causal_mask torch.bfloat16 134.192 141.523 606.269 1424.318 0.948 0.426
(16, 16, 512, 128) relative_bias torch.bfloat16 135.711 138.639 606.283 1327.974 0.979 0.457
(16, 16, 512, 128) head_bias torch.bfloat16 135.552 140.555 607.107 1347.370 0.964 0.451
(16, 16, 512, 256) noop torch.bfloat16 275.113 315.144 1301.583 5268.153 0.873 0.247
(16, 16, 512, 256) causal_mask torch.bfloat16 274.867 328.106 1302.513 5770.594 0.838 0.226
(16, 16, 512, 256) relative_bias torch.bfloat16 276.052 321.770 1302.904 5241.920 0.858 0.249
(16, 16, 512, 256) head_bias torch.bfloat16 271.409 328.839 1302.142 5266.037 0.825 0.247
(16, 16, 1024, 64) noop torch.bfloat16 260.489 237.463 955.884 1817.558 1.097 0.526
(16, 16, 1024, 64) causal_mask torch.bfloat16 262.378 254.350 955.280 1843.807 1.032 0.518
(16, 16, 1024, 64) relative_bias torch.bfloat16 261.338 268.253 956.038 1820.036 0.974 0.525
(16, 16, 1024, 64) head_bias torch.bfloat16 262.153 264.156 956.023 1810.076 0.992 0.528
(16, 16, 1024, 128) noop torch.bfloat16 476.475 461.413 1760.578 4306.521 1.033 0.409
(16, 16, 1024, 128) causal_mask torch.bfloat16 473.794 479.178 1761.277 4619.439 0.989 0.381
(16, 16, 1024, 128) relative_bias torch.bfloat16 473.839 463.282 1758.692 4290.562 1.023 0.410
(16, 16, 1024, 128) head_bias torch.bfloat16 472.979 472.896 1763.086 4367.931 1.000 0.404
(16, 16, 1024, 256) noop torch.bfloat16 1014.184 1026.764 3922.997 19104.147 0.988 0.205
(16, 16, 1024, 256) causal_mask torch.bfloat16 1013.217 1039.046 3928.382 21086.281 0.975 0.186
(16, 16, 1024, 256) relative_bias torch.bfloat16 1008.519 1015.278 3922.133 18980.652 0.993 0.207
(16, 16, 1024, 256) head_bias torch.bfloat16 1011.360 1047.542 3931.245 19069.172 0.965 0.206
(16, 16, 4096, 64) noop torch.bfloat16 3929.850 3325.667 11411.704 23344.280 1.182 0.489
(16, 16, 4096, 64) causal_mask torch.bfloat16 3885.262 3581.544 11390.515 23725.639 1.085 0.480
(16, 16, 4096, 64) relative_bias torch.bfloat16 3865.737 3537.308 11489.901 23406.330 1.093 0.491
(16, 16, 4096, 64) head_bias torch.bfloat16 3880.530 3665.249 11484.411 23299.496 1.059 0.493
(16, 16, 4096, 128) noop torch.bfloat16 7030.306 6745.715 20621.264 57464.096 1.042 0.359
(16, 16, 4096, 128) causal_mask torch.bfloat16 7095.414 7034.385 20410.656 61660.511 1.009 0.331
(16, 16, 4096, 128) relative_bias torch.bfloat16 7084.779 6686.497 20315.161 57243.969 1.060 0.355
(16, 16, 4096, 128) head_bias torch.bfloat16 7075.367 6863.305 20494.385 58481.953 1.031 0.350
(16, 16, 4096, 256) noop torch.bfloat16 15612.741 14297.482 55306.847 281161.865 1.092 0.197
(16, 16, 4096, 256) causal_mask torch.bfloat16 15326.592 14263.878 55227.806 313063.232 1.075 0.176
(16, 16, 4096, 256) relative_bias torch.bfloat16 15297.963 14007.379 54558.029 279529.175 1.092 0.195
(16, 16, 4096, 256) head_bias torch.bfloat16 15216.160 14276.027 55081.581 280996.826 1.066 0.196

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

Copy link

pytorch-bot bot commented May 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125515

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (7 Unrelated Failures)

As of commit 69e8b7a with merge base 9689532 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg force-pushed the add-lowering-to-flex-attention-backkwards branch 6 times, most recently from 8333271 to b725537 Compare May 7, 2024 17:12
torch/_inductor/select_algorithm.py Outdated Show resolved Hide resolved
torch/_inductor/select_algorithm.py Show resolved Hide resolved
torch/_inductor/kernel/flex_attention.py Outdated Show resolved Hide resolved
@drisspg drisspg force-pushed the add-lowering-to-flex-attention-backkwards branch 4 times, most recently from aa5cfda to 9d10e81 Compare May 9, 2024 00:19
@drisspg drisspg force-pushed the add-lowering-to-flex-attention-backkwards branch from 9d10e81 to a317d9d Compare May 9, 2024 01:20
@@ -320,6 +328,8 @@ def store_output(
indices: Union[List[Any], Tuple[Any]],
val: str,
mask: Optional[str] = None,
indent_width: int = 4,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to figure out this one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want indent width, but probably definitely not fake_output lol

@drisspg drisspg force-pushed the add-lowering-to-flex-attention-backkwards branch 7 times, most recently from d996891 to 50a2da9 Compare May 11, 2024 17:51
@drisspg drisspg marked this pull request as ready for review May 11, 2024 17:55
@drisspg drisspg force-pushed the add-lowering-to-flex-attention-backkwards branch from 50a2da9 to 8159d4a Compare May 11, 2024 23:58
@pytorchmergebot
Copy link
Collaborator

@drisspg your PR has been successfully reverted.

@drisspg
Copy link
Contributor Author

drisspg commented May 16, 2024

@huydhn hmmm so what exactly do I need to do?

@clee2000
Copy link
Contributor

@huydhn hmmm so what exactly do I need to do?

If you do mean to use >6GB of memory, do something like #126399

(Note for @clee2000, I remember we are still using 2 parallel processed on CUDA sm86 jobs. So this is just FYI and probably not related to your change to increase that number)

We're using 3 for sm86

@drisspg drisspg force-pushed the add-lowering-to-flex-attention-backkwards branch from c7968cd to 69e8b7a Compare May 16, 2024 18:49
@drisspg drisspg requested a review from a team as a code owner May 16, 2024 18:49
@drisspg
Copy link
Contributor Author

drisspg commented May 16, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@drisspg
Copy link
Contributor Author

drisspg commented May 16, 2024

@pytorchbot -h

Copy link

pytorch-bot bot commented May 16, 2024

PyTorchBot Help

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

In order to invoke the bot on your PR, include a line that starts with
@pytorchbot anywhere in a comment. That line will form the command; no
multi-line commands are allowed. Some commands may be used on issues as specified below.

Example:
    Some extra context, blah blah, wow this PR looks awesome

    @pytorchbot merge

optional arguments:
  -h, --help            Show this help message and exit.

command:
  {merge,revert,rebase,label,drci,cherry-pick,close}
    merge               Merge a PR
    revert              Revert a PR
    rebase              Rebase a PR
    label               Add label to a PR
    drci                Update Dr. CI
    cherry-pick         Cherry pick a PR onto a release branch
    close               Close a PR

Merge

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Merge an accepted PR, subject to the rules in .github/merge_rules.json.
By default, this will wait for all required checks (lint, pull) to succeed before merging.

optional arguments:
  -f MESSAGE, --force MESSAGE
                        Merge without checking anything. This requires a reason for auditting purpose, for example:
                        @pytorchbot merge -f 'Minor update to fix lint. Expecting all PR tests to pass'
                        
                        Please use `-f` as last resort, prefer `--ignore-current` to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.
  -i, --ignore-current  Merge while ignoring the currently failing jobs.  Behaves like -f if there are no pending jobs.
  -ic                   Old flag for --ignore-current. Deprecated in favor of -i.
  -r [{viable/strict,main}], --rebase [{viable/strict,main}]
                        Rebase the PR to re run checks before merging.  Accepts viable/strict or main as branch options and will default to viable/strict if not specified.

Revert

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Revert a merged PR. This requires that you are a Meta employee.

Example:
  @pytorchbot revert -m="This is breaking tests on trunk. hud.pytorch.org/" -c=nosignal

optional arguments:
  -m MESSAGE, --message MESSAGE
                        The reason you are reverting, will be put in the commit message. Must be longer than 3 words.
  -c {nosignal,ignoredsignal,landrace,weird,ghfirst}, --classification {nosignal,ignoredsignal,landrace,weird,ghfirst}
                        A machine-friendly classification of the revert reason.

Rebase

usage: @pytorchbot rebase [-s | -b BRANCH]

Rebase a PR. Rebasing defaults to the stable viable/strict branch of pytorch.
Repeat contributor may use this command to rebase their PR.

optional arguments:
  -s, --stable          [DEPRECATED] Rebase onto viable/strict
  -b BRANCH, --branch BRANCH
                        Branch you would like to rebase to

Label

usage: @pytorchbot label labels [labels ...]

Adds label to a PR or Issue [Can be used on Issues]

positional arguments:
  labels  Labels to add to given Pull Request or Issue [Can be used on Issues]

Dr CI

usage: @pytorchbot drci 

Update Dr. CI. Updates the Dr. CI comment on the PR in case it's gotten out of sync with actual CI results.

cherry-pick

usage: @pytorchbot cherry-pick --onto ONTO [--fixes FIXES] -c
                               {regression,critical,fixnewfeature,docs,release}

Cherry pick a pull request onto a release branch for inclusion in a release

optional arguments:
  --onto ONTO           Branch you would like to cherry pick onto (Example: release/2.1)
  --fixes FIXES         Link to the issue that your PR fixes (Example: https://github.com/pytorch/pytorch/issues/110666)
  -c {regression,critical,fixnewfeature,docs,release}, --classification {regression,critical,fixnewfeature,docs,release}
                        A machine-friendly classification of the cherry-pick reason.

Close

usage: @pytorchbot close

Close a PR [Can be used on issues]

@drisspg
Copy link
Contributor Author

drisspg commented May 16, 2024

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: Lint / lintrunner-noclang / linux-job

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@drisspg
Copy link
Contributor Author

drisspg commented May 17, 2024

@pytorchbot merge -f "spoke with Catherine, I added all of flex's test to the serial list, and these failures look unrelated"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
# Summary
#### What does this PR do?
It enables Inductor to actually generate the fused flex attention kernel for the backwards

I did some other things along the way:
- Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel.
- The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization.
- I didnt correctly register the decomp table + IndexMode when I landed: pytorch#123902, this remedies that.
- The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention.
- This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk'
- I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications)
- I updated the benchmark to also profile bwds performance

### Benchmark Numbers:
_The current implementation is not parallelizing over ctx length in the bwd_
FWD Speedups

| Type    |   Speedup | shape              | score_mod   | dtype          |
|---------|-----------|--------------------|-------------|----------------|
| Average |     0.991 |                    |             |                |
| Max     |     1.182 | (16, 16, 4096, 64) | noop        | torch.bfloat16 |
| Min     |     0.796 | (2, 16, 512, 256)  | head_bias   | torch.bfloat16 |

BWD Speedups

| Type    |   Speedup | shape              | score_mod   | dtype          |
|---------|-----------|--------------------|-------------|----------------|
| Average |     0.291 |                    |             |                |
| Max     |     0.652 | (8, 16, 512, 64)   | head_bias   | torch.bfloat16 |
| Min     |     0.073 | (2, 16, 4096, 128) | head_bias   | torch.bfloat16 |

<details>

<summary>Full Data</summary>

| shape               | score_mod     | dtype          |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| (2, 16, 512, 64)    | noop          | torch.bfloat16 |           19.936 |              19.092 |           57.851 |             193.564 |         1.044 |         0.299 |
| (2, 16, 512, 64)    | causal_mask   | torch.bfloat16 |           19.955 |              19.497 |           57.662 |             206.278 |         1.024 |         0.280 |
| (2, 16, 512, 64)    | relative_bias | torch.bfloat16 |           19.455 |              21.297 |           57.674 |             195.219 |         0.913 |         0.295 |
| (2, 16, 512, 64)    | head_bias     | torch.bfloat16 |           19.958 |              21.289 |           57.674 |             193.859 |         0.938 |         0.298 |
| (2, 16, 512, 128)   | noop          | torch.bfloat16 |           28.157 |              28.615 |           82.831 |             454.211 |         0.984 |         0.182 |
| (2, 16, 512, 128)   | causal_mask   | torch.bfloat16 |           28.154 |              28.444 |           83.091 |             432.083 |         0.990 |         0.192 |
| (2, 16, 512, 128)   | relative_bias | torch.bfloat16 |           28.722 |              27.897 |           83.175 |             446.789 |         1.030 |         0.186 |
| (2, 16, 512, 128)   | head_bias     | torch.bfloat16 |           28.299 |              27.673 |           83.052 |             459.179 |         1.023 |         0.181 |
| (2, 16, 512, 256)   | noop          | torch.bfloat16 |           41.167 |              50.504 |          175.019 |            1083.545 |         0.815 |         0.162 |
| (2, 16, 512, 256)   | causal_mask   | torch.bfloat16 |           41.656 |              51.933 |          175.078 |            1171.176 |         0.802 |         0.149 |
| (2, 16, 512, 256)   | relative_bias | torch.bfloat16 |           41.697 |              50.722 |          175.159 |            1097.312 |         0.822 |         0.160 |
| (2, 16, 512, 256)   | head_bias     | torch.bfloat16 |           41.690 |              52.387 |          175.184 |            1097.336 |         0.796 |         0.160 |
| (2, 16, 1024, 64)   | noop          | torch.bfloat16 |           39.232 |              37.454 |          127.847 |             612.430 |         1.047 |         0.209 |
| (2, 16, 1024, 64)   | causal_mask   | torch.bfloat16 |           39.930 |              39.599 |          127.755 |             665.359 |         1.008 |         0.192 |
| (2, 16, 1024, 64)   | relative_bias | torch.bfloat16 |           39.417 |              41.304 |          127.902 |             614.990 |         0.954 |         0.208 |
| (2, 16, 1024, 64)   | head_bias     | torch.bfloat16 |           39.965 |              42.034 |          127.953 |             613.273 |         0.951 |         0.209 |
| (2, 16, 1024, 128)  | noop          | torch.bfloat16 |           63.964 |              71.024 |          226.510 |            1637.669 |         0.901 |         0.138 |
| (2, 16, 1024, 128)  | causal_mask   | torch.bfloat16 |           63.843 |              72.451 |          226.750 |            1558.949 |         0.881 |         0.145 |
| (2, 16, 1024, 128)  | relative_bias | torch.bfloat16 |           64.301 |              70.487 |          226.651 |            1610.063 |         0.912 |         0.141 |
| (2, 16, 1024, 128)  | head_bias     | torch.bfloat16 |           64.033 |              71.394 |          226.676 |            1668.511 |         0.897 |         0.136 |
| (2, 16, 1024, 256)  | noop          | torch.bfloat16 |          129.348 |             141.390 |          507.337 |            4405.175 |         0.915 |         0.115 |
| (2, 16, 1024, 256)  | causal_mask   | torch.bfloat16 |          129.538 |             145.680 |          507.178 |            4768.874 |         0.889 |         0.106 |
| (2, 16, 1024, 256)  | relative_bias | torch.bfloat16 |          129.438 |             142.782 |          507.004 |            4401.002 |         0.907 |         0.115 |
| (2, 16, 1024, 256)  | head_bias     | torch.bfloat16 |          129.058 |             146.242 |          507.547 |            4434.251 |         0.883 |         0.114 |
| (2, 16, 4096, 64)   | noop          | torch.bfloat16 |          481.606 |             409.120 |         1440.890 |           14147.269 |         1.177 |         0.102 |
| (2, 16, 4096, 64)   | causal_mask   | torch.bfloat16 |          480.227 |             438.847 |         1434.419 |           14973.386 |         1.094 |         0.096 |
| (2, 16, 4096, 64)   | relative_bias | torch.bfloat16 |          480.831 |             458.104 |         1432.935 |           14193.253 |         1.050 |         0.101 |
| (2, 16, 4096, 64)   | head_bias     | torch.bfloat16 |          480.749 |             452.497 |         1437.040 |           14084.869 |         1.062 |         0.102 |
| (2, 16, 4096, 128)  | noop          | torch.bfloat16 |          872.534 |             848.275 |         2600.895 |           35156.849 |         1.029 |         0.074 |
| (2, 16, 4096, 128)  | causal_mask   | torch.bfloat16 |          872.647 |             868.279 |         2587.581 |           31919.531 |         1.005 |         0.081 |
| (2, 16, 4096, 128)  | relative_bias | torch.bfloat16 |          871.484 |             827.644 |         2593.989 |           34805.634 |         1.053 |         0.075 |
| (2, 16, 4096, 128)  | head_bias     | torch.bfloat16 |          871.422 |             856.437 |         2602.482 |           35708.591 |         1.017 |         0.073 |
| (2, 16, 4096, 256)  | noop          | torch.bfloat16 |         1904.497 |            1758.183 |         6122.416 |           66754.593 |         1.083 |         0.092 |
| (2, 16, 4096, 256)  | causal_mask   | torch.bfloat16 |         1911.174 |            1762.821 |         6113.207 |           72759.392 |         1.084 |         0.084 |
| (2, 16, 4096, 256)  | relative_bias | torch.bfloat16 |         1911.254 |            1727.108 |         6123.530 |           66577.988 |         1.107 |         0.092 |
| (2, 16, 4096, 256)  | head_bias     | torch.bfloat16 |         1916.977 |            1801.804 |         6118.158 |           67359.680 |         1.064 |         0.091 |
| (8, 16, 512, 64)    | noop          | torch.bfloat16 |           44.984 |              43.974 |          170.276 |             262.259 |         1.023 |         0.649 |
| (8, 16, 512, 64)    | causal_mask   | torch.bfloat16 |           45.001 |              46.265 |          170.509 |             274.893 |         0.973 |         0.620 |
| (8, 16, 512, 64)    | relative_bias | torch.bfloat16 |           45.466 |              48.211 |          170.606 |             262.759 |         0.943 |         0.649 |
| (8, 16, 512, 64)    | head_bias     | torch.bfloat16 |           45.481 |              48.435 |          170.267 |             261.265 |         0.939 |         0.652 |
| (8, 16, 512, 128)   | noop          | torch.bfloat16 |           72.565 |              74.736 |          313.220 |             773.126 |         0.971 |         0.405 |
| (8, 16, 512, 128)   | causal_mask   | torch.bfloat16 |           72.015 |              75.755 |          313.311 |             775.513 |         0.951 |         0.404 |
| (8, 16, 512, 128)   | relative_bias | torch.bfloat16 |           72.105 |              74.189 |          313.806 |             769.238 |         0.972 |         0.408 |
| (8, 16, 512, 128)   | head_bias     | torch.bfloat16 |           72.005 |              74.364 |          313.509 |             775.237 |         0.968 |         0.404 |
| (8, 16, 512, 256)   | noop          | torch.bfloat16 |          138.656 |             165.453 |          663.707 |            2672.067 |         0.838 |         0.248 |
| (8, 16, 512, 256)   | causal_mask   | torch.bfloat16 |          139.096 |             172.613 |          663.593 |            2926.538 |         0.806 |         0.227 |
| (8, 16, 512, 256)   | relative_bias | torch.bfloat16 |          139.500 |             168.417 |          663.938 |            2658.629 |         0.828 |         0.250 |
| (8, 16, 512, 256)   | head_bias     | torch.bfloat16 |          139.776 |             173.549 |          662.920 |            2667.266 |         0.805 |         0.249 |
| (8, 16, 1024, 64)   | noop          | torch.bfloat16 |          134.883 |             125.004 |          484.706 |            1195.254 |         1.079 |         0.406 |
| (8, 16, 1024, 64)   | causal_mask   | torch.bfloat16 |          134.297 |             132.875 |          485.420 |            1234.953 |         1.011 |         0.393 |
| (8, 16, 1024, 64)   | relative_bias | torch.bfloat16 |          134.839 |             139.231 |          485.470 |            1198.556 |         0.968 |         0.405 |
| (8, 16, 1024, 64)   | head_bias     | torch.bfloat16 |          133.822 |             136.449 |          485.608 |            1189.198 |         0.981 |         0.408 |
| (8, 16, 1024, 128)  | noop          | torch.bfloat16 |          235.470 |             234.765 |          886.094 |            2662.944 |         1.003 |         0.333 |
| (8, 16, 1024, 128)  | causal_mask   | torch.bfloat16 |          236.305 |             241.382 |          886.293 |            2646.984 |         0.979 |         0.335 |
| (8, 16, 1024, 128)  | relative_bias | torch.bfloat16 |          236.414 |             233.980 |          885.250 |            2642.178 |         1.010 |         0.335 |
| (8, 16, 1024, 128)  | head_bias     | torch.bfloat16 |          237.176 |             239.040 |          885.754 |            2665.242 |         0.992 |         0.332 |
| (8, 16, 1024, 256)  | noop          | torch.bfloat16 |          504.445 |             517.855 |         1978.956 |            9592.906 |         0.974 |         0.206 |
| (8, 16, 1024, 256)  | causal_mask   | torch.bfloat16 |          502.428 |             536.002 |         1978.611 |           10607.342 |         0.937 |         0.187 |
| (8, 16, 1024, 256)  | relative_bias | torch.bfloat16 |          503.396 |             523.960 |         1977.993 |            9539.284 |         0.961 |         0.207 |
| (8, 16, 1024, 256)  | head_bias     | torch.bfloat16 |          503.818 |             536.014 |         1980.131 |            9576.262 |         0.940 |         0.207 |
| (8, 16, 4096, 64)   | noop          | torch.bfloat16 |         1970.139 |            1674.930 |         5750.940 |           16724.134 |         1.176 |         0.344 |
| (8, 16, 4096, 64)   | causal_mask   | torch.bfloat16 |         1959.036 |            1775.056 |         5780.512 |           17390.350 |         1.104 |         0.332 |
| (8, 16, 4096, 64)   | relative_bias | torch.bfloat16 |         1947.198 |            1773.869 |         5780.643 |           16779.699 |         1.098 |         0.345 |
| (8, 16, 4096, 64)   | head_bias     | torch.bfloat16 |         1963.935 |            1829.502 |         5780.018 |           16703.259 |         1.073 |         0.346 |
| (8, 16, 4096, 128)  | noop          | torch.bfloat16 |         3582.711 |            3362.623 |        10436.069 |           36415.565 |         1.065 |         0.287 |
| (8, 16, 4096, 128)  | causal_mask   | torch.bfloat16 |         3581.504 |            3499.472 |        10346.869 |           36164.959 |         1.023 |         0.286 |
| (8, 16, 4096, 128)  | relative_bias | torch.bfloat16 |         3589.779 |            3337.849 |        10529.621 |           36261.696 |         1.075 |         0.290 |
| (8, 16, 4096, 128)  | head_bias     | torch.bfloat16 |         3602.265 |            3436.444 |        10458.660 |           36507.790 |         1.048 |         0.286 |
| (8, 16, 4096, 256)  | noop          | torch.bfloat16 |         7695.923 |            7126.275 |        24643.009 |          140949.081 |         1.080 |         0.175 |
| (8, 16, 4096, 256)  | causal_mask   | torch.bfloat16 |         7679.939 |            7186.252 |        24538.105 |          157156.067 |         1.069 |         0.156 |
| (8, 16, 4096, 256)  | relative_bias | torch.bfloat16 |         7681.374 |            6994.832 |        24549.713 |          140077.179 |         1.098 |         0.175 |
| (8, 16, 4096, 256)  | head_bias     | torch.bfloat16 |         7679.822 |            7212.278 |        24627.823 |          140675.003 |         1.065 |         0.175 |
| (16, 16, 512, 64)   | noop          | torch.bfloat16 |           80.126 |              78.291 |          333.719 |             541.165 |         1.023 |         0.617 |
| (16, 16, 512, 64)   | causal_mask   | torch.bfloat16 |           80.065 |              81.696 |          333.779 |             551.113 |         0.980 |         0.606 |
| (16, 16, 512, 64)   | relative_bias | torch.bfloat16 |           80.138 |              86.715 |          333.364 |             542.118 |         0.924 |         0.615 |
| (16, 16, 512, 64)   | head_bias     | torch.bfloat16 |           80.415 |              85.204 |          333.294 |             536.840 |         0.944 |         0.621 |
| (16, 16, 512, 128)  | noop          | torch.bfloat16 |          134.964 |             138.025 |          607.093 |            1333.102 |         0.978 |         0.455 |
| (16, 16, 512, 128)  | causal_mask   | torch.bfloat16 |          134.192 |             141.523 |          606.269 |            1424.318 |         0.948 |         0.426 |
| (16, 16, 512, 128)  | relative_bias | torch.bfloat16 |          135.711 |             138.639 |          606.283 |            1327.974 |         0.979 |         0.457 |
| (16, 16, 512, 128)  | head_bias     | torch.bfloat16 |          135.552 |             140.555 |          607.107 |            1347.370 |         0.964 |         0.451 |
| (16, 16, 512, 256)  | noop          | torch.bfloat16 |          275.113 |             315.144 |         1301.583 |            5268.153 |         0.873 |         0.247 |
| (16, 16, 512, 256)  | causal_mask   | torch.bfloat16 |          274.867 |             328.106 |         1302.513 |            5770.594 |         0.838 |         0.226 |
| (16, 16, 512, 256)  | relative_bias | torch.bfloat16 |          276.052 |             321.770 |         1302.904 |            5241.920 |         0.858 |         0.249 |
| (16, 16, 512, 256)  | head_bias     | torch.bfloat16 |          271.409 |             328.839 |         1302.142 |            5266.037 |         0.825 |         0.247 |
| (16, 16, 1024, 64)  | noop          | torch.bfloat16 |          260.489 |             237.463 |          955.884 |            1817.558 |         1.097 |         0.526 |
| (16, 16, 1024, 64)  | causal_mask   | torch.bfloat16 |          262.378 |             254.350 |          955.280 |            1843.807 |         1.032 |         0.518 |
| (16, 16, 1024, 64)  | relative_bias | torch.bfloat16 |          261.338 |             268.253 |          956.038 |            1820.036 |         0.974 |         0.525 |
| (16, 16, 1024, 64)  | head_bias     | torch.bfloat16 |          262.153 |             264.156 |          956.023 |            1810.076 |         0.992 |         0.528 |
| (16, 16, 1024, 128) | noop          | torch.bfloat16 |          476.475 |             461.413 |         1760.578 |            4306.521 |         1.033 |         0.409 |
| (16, 16, 1024, 128) | causal_mask   | torch.bfloat16 |          473.794 |             479.178 |         1761.277 |            4619.439 |         0.989 |         0.381 |
| (16, 16, 1024, 128) | relative_bias | torch.bfloat16 |          473.839 |             463.282 |         1758.692 |            4290.562 |         1.023 |         0.410 |
| (16, 16, 1024, 128) | head_bias     | torch.bfloat16 |          472.979 |             472.896 |         1763.086 |            4367.931 |         1.000 |         0.404 |
| (16, 16, 1024, 256) | noop          | torch.bfloat16 |         1014.184 |            1026.764 |         3922.997 |           19104.147 |         0.988 |         0.205 |
| (16, 16, 1024, 256) | causal_mask   | torch.bfloat16 |         1013.217 |            1039.046 |         3928.382 |           21086.281 |         0.975 |         0.186 |
| (16, 16, 1024, 256) | relative_bias | torch.bfloat16 |         1008.519 |            1015.278 |         3922.133 |           18980.652 |         0.993 |         0.207 |
| (16, 16, 1024, 256) | head_bias     | torch.bfloat16 |         1011.360 |            1047.542 |         3931.245 |           19069.172 |         0.965 |         0.206 |
| (16, 16, 4096, 64)  | noop          | torch.bfloat16 |         3929.850 |            3325.667 |        11411.704 |           23344.280 |         1.182 |         0.489 |
| (16, 16, 4096, 64)  | causal_mask   | torch.bfloat16 |         3885.262 |            3581.544 |        11390.515 |           23725.639 |         1.085 |         0.480 |
| (16, 16, 4096, 64)  | relative_bias | torch.bfloat16 |         3865.737 |            3537.308 |        11489.901 |           23406.330 |         1.093 |         0.491 |
| (16, 16, 4096, 64)  | head_bias     | torch.bfloat16 |         3880.530 |            3665.249 |        11484.411 |           23299.496 |         1.059 |         0.493 |
| (16, 16, 4096, 128) | noop          | torch.bfloat16 |         7030.306 |            6745.715 |        20621.264 |           57464.096 |         1.042 |         0.359 |
| (16, 16, 4096, 128) | causal_mask   | torch.bfloat16 |         7095.414 |            7034.385 |        20410.656 |           61660.511 |         1.009 |         0.331 |
| (16, 16, 4096, 128) | relative_bias | torch.bfloat16 |         7084.779 |            6686.497 |        20315.161 |           57243.969 |         1.060 |         0.355 |
| (16, 16, 4096, 128) | head_bias     | torch.bfloat16 |         7075.367 |            6863.305 |        20494.385 |           58481.953 |         1.031 |         0.350 |
| (16, 16, 4096, 256) | noop          | torch.bfloat16 |        15612.741 |           14297.482 |        55306.847 |          281161.865 |         1.092 |         0.197 |
| (16, 16, 4096, 256) | causal_mask   | torch.bfloat16 |        15326.592 |           14263.878 |        55227.806 |          313063.232 |         1.075 |         0.176 |
| (16, 16, 4096, 256) | relative_bias | torch.bfloat16 |        15297.963 |           14007.379 |        54558.029 |          279529.175 |         1.092 |         0.195 |
| (16, 16, 4096, 256) | head_bias     | torch.bfloat16 |        15216.160 |           14276.027 |        55081.581 |          280996.826 |         1.066 |         0.196 |

</details>

Pull Request resolved: pytorch#125515
Approved by: https://github.com/Chillee
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
# Summary
#### What does this PR do?
It enables Inductor to actually generate the fused flex attention kernel for the backwards

I did some other things along the way:
- Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel.
- The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization.
- I didnt correctly register the decomp table + IndexMode when I landed: pytorch#123902, this remedies that.
- The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention.
- This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk'
- I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications)
- I updated the benchmark to also profile bwds performance

### Benchmark Numbers:
_The current implementation is not parallelizing over ctx length in the bwd_
FWD Speedups

| Type    |   Speedup | shape              | score_mod   | dtype          |
|---------|-----------|--------------------|-------------|----------------|
| Average |     0.991 |                    |             |                |
| Max     |     1.182 | (16, 16, 4096, 64) | noop        | torch.bfloat16 |
| Min     |     0.796 | (2, 16, 512, 256)  | head_bias   | torch.bfloat16 |

BWD Speedups

| Type    |   Speedup | shape              | score_mod   | dtype          |
|---------|-----------|--------------------|-------------|----------------|
| Average |     0.291 |                    |             |                |
| Max     |     0.652 | (8, 16, 512, 64)   | head_bias   | torch.bfloat16 |
| Min     |     0.073 | (2, 16, 4096, 128) | head_bias   | torch.bfloat16 |

<details>

<summary>Full Data</summary>

| shape               | score_mod     | dtype          |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| (2, 16, 512, 64)    | noop          | torch.bfloat16 |           19.936 |              19.092 |           57.851 |             193.564 |         1.044 |         0.299 |
| (2, 16, 512, 64)    | causal_mask   | torch.bfloat16 |           19.955 |              19.497 |           57.662 |             206.278 |         1.024 |         0.280 |
| (2, 16, 512, 64)    | relative_bias | torch.bfloat16 |           19.455 |              21.297 |           57.674 |             195.219 |         0.913 |         0.295 |
| (2, 16, 512, 64)    | head_bias     | torch.bfloat16 |           19.958 |              21.289 |           57.674 |             193.859 |         0.938 |         0.298 |
| (2, 16, 512, 128)   | noop          | torch.bfloat16 |           28.157 |              28.615 |           82.831 |             454.211 |         0.984 |         0.182 |
| (2, 16, 512, 128)   | causal_mask   | torch.bfloat16 |           28.154 |              28.444 |           83.091 |             432.083 |         0.990 |         0.192 |
| (2, 16, 512, 128)   | relative_bias | torch.bfloat16 |           28.722 |              27.897 |           83.175 |             446.789 |         1.030 |         0.186 |
| (2, 16, 512, 128)   | head_bias     | torch.bfloat16 |           28.299 |              27.673 |           83.052 |             459.179 |         1.023 |         0.181 |
| (2, 16, 512, 256)   | noop          | torch.bfloat16 |           41.167 |              50.504 |          175.019 |            1083.545 |         0.815 |         0.162 |
| (2, 16, 512, 256)   | causal_mask   | torch.bfloat16 |           41.656 |              51.933 |          175.078 |            1171.176 |         0.802 |         0.149 |
| (2, 16, 512, 256)   | relative_bias | torch.bfloat16 |           41.697 |              50.722 |          175.159 |            1097.312 |         0.822 |         0.160 |
| (2, 16, 512, 256)   | head_bias     | torch.bfloat16 |           41.690 |              52.387 |          175.184 |            1097.336 |         0.796 |         0.160 |
| (2, 16, 1024, 64)   | noop          | torch.bfloat16 |           39.232 |              37.454 |          127.847 |             612.430 |         1.047 |         0.209 |
| (2, 16, 1024, 64)   | causal_mask   | torch.bfloat16 |           39.930 |              39.599 |          127.755 |             665.359 |         1.008 |         0.192 |
| (2, 16, 1024, 64)   | relative_bias | torch.bfloat16 |           39.417 |              41.304 |          127.902 |             614.990 |         0.954 |         0.208 |
| (2, 16, 1024, 64)   | head_bias     | torch.bfloat16 |           39.965 |              42.034 |          127.953 |             613.273 |         0.951 |         0.209 |
| (2, 16, 1024, 128)  | noop          | torch.bfloat16 |           63.964 |              71.024 |          226.510 |            1637.669 |         0.901 |         0.138 |
| (2, 16, 1024, 128)  | causal_mask   | torch.bfloat16 |           63.843 |              72.451 |          226.750 |            1558.949 |         0.881 |         0.145 |
| (2, 16, 1024, 128)  | relative_bias | torch.bfloat16 |           64.301 |              70.487 |          226.651 |            1610.063 |         0.912 |         0.141 |
| (2, 16, 1024, 128)  | head_bias     | torch.bfloat16 |           64.033 |              71.394 |          226.676 |            1668.511 |         0.897 |         0.136 |
| (2, 16, 1024, 256)  | noop          | torch.bfloat16 |          129.348 |             141.390 |          507.337 |            4405.175 |         0.915 |         0.115 |
| (2, 16, 1024, 256)  | causal_mask   | torch.bfloat16 |          129.538 |             145.680 |          507.178 |            4768.874 |         0.889 |         0.106 |
| (2, 16, 1024, 256)  | relative_bias | torch.bfloat16 |          129.438 |             142.782 |          507.004 |            4401.002 |         0.907 |         0.115 |
| (2, 16, 1024, 256)  | head_bias     | torch.bfloat16 |          129.058 |             146.242 |          507.547 |            4434.251 |         0.883 |         0.114 |
| (2, 16, 4096, 64)   | noop          | torch.bfloat16 |          481.606 |             409.120 |         1440.890 |           14147.269 |         1.177 |         0.102 |
| (2, 16, 4096, 64)   | causal_mask   | torch.bfloat16 |          480.227 |             438.847 |         1434.419 |           14973.386 |         1.094 |         0.096 |
| (2, 16, 4096, 64)   | relative_bias | torch.bfloat16 |          480.831 |             458.104 |         1432.935 |           14193.253 |         1.050 |         0.101 |
| (2, 16, 4096, 64)   | head_bias     | torch.bfloat16 |          480.749 |             452.497 |         1437.040 |           14084.869 |         1.062 |         0.102 |
| (2, 16, 4096, 128)  | noop          | torch.bfloat16 |          872.534 |             848.275 |         2600.895 |           35156.849 |         1.029 |         0.074 |
| (2, 16, 4096, 128)  | causal_mask   | torch.bfloat16 |          872.647 |             868.279 |         2587.581 |           31919.531 |         1.005 |         0.081 |
| (2, 16, 4096, 128)  | relative_bias | torch.bfloat16 |          871.484 |             827.644 |         2593.989 |           34805.634 |         1.053 |         0.075 |
| (2, 16, 4096, 128)  | head_bias     | torch.bfloat16 |          871.422 |             856.437 |         2602.482 |           35708.591 |         1.017 |         0.073 |
| (2, 16, 4096, 256)  | noop          | torch.bfloat16 |         1904.497 |            1758.183 |         6122.416 |           66754.593 |         1.083 |         0.092 |
| (2, 16, 4096, 256)  | causal_mask   | torch.bfloat16 |         1911.174 |            1762.821 |         6113.207 |           72759.392 |         1.084 |         0.084 |
| (2, 16, 4096, 256)  | relative_bias | torch.bfloat16 |         1911.254 |            1727.108 |         6123.530 |           66577.988 |         1.107 |         0.092 |
| (2, 16, 4096, 256)  | head_bias     | torch.bfloat16 |         1916.977 |            1801.804 |         6118.158 |           67359.680 |         1.064 |         0.091 |
| (8, 16, 512, 64)    | noop          | torch.bfloat16 |           44.984 |              43.974 |          170.276 |             262.259 |         1.023 |         0.649 |
| (8, 16, 512, 64)    | causal_mask   | torch.bfloat16 |           45.001 |              46.265 |          170.509 |             274.893 |         0.973 |         0.620 |
| (8, 16, 512, 64)    | relative_bias | torch.bfloat16 |           45.466 |              48.211 |          170.606 |             262.759 |         0.943 |         0.649 |
| (8, 16, 512, 64)    | head_bias     | torch.bfloat16 |           45.481 |              48.435 |          170.267 |             261.265 |         0.939 |         0.652 |
| (8, 16, 512, 128)   | noop          | torch.bfloat16 |           72.565 |              74.736 |          313.220 |             773.126 |         0.971 |         0.405 |
| (8, 16, 512, 128)   | causal_mask   | torch.bfloat16 |           72.015 |              75.755 |          313.311 |             775.513 |         0.951 |         0.404 |
| (8, 16, 512, 128)   | relative_bias | torch.bfloat16 |           72.105 |              74.189 |          313.806 |             769.238 |         0.972 |         0.408 |
| (8, 16, 512, 128)   | head_bias     | torch.bfloat16 |           72.005 |              74.364 |          313.509 |             775.237 |         0.968 |         0.404 |
| (8, 16, 512, 256)   | noop          | torch.bfloat16 |          138.656 |             165.453 |          663.707 |            2672.067 |         0.838 |         0.248 |
| (8, 16, 512, 256)   | causal_mask   | torch.bfloat16 |          139.096 |             172.613 |          663.593 |            2926.538 |         0.806 |         0.227 |
| (8, 16, 512, 256)   | relative_bias | torch.bfloat16 |          139.500 |             168.417 |          663.938 |            2658.629 |         0.828 |         0.250 |
| (8, 16, 512, 256)   | head_bias     | torch.bfloat16 |          139.776 |             173.549 |          662.920 |            2667.266 |         0.805 |         0.249 |
| (8, 16, 1024, 64)   | noop          | torch.bfloat16 |          134.883 |             125.004 |          484.706 |            1195.254 |         1.079 |         0.406 |
| (8, 16, 1024, 64)   | causal_mask   | torch.bfloat16 |          134.297 |             132.875 |          485.420 |            1234.953 |         1.011 |         0.393 |
| (8, 16, 1024, 64)   | relative_bias | torch.bfloat16 |          134.839 |             139.231 |          485.470 |            1198.556 |         0.968 |         0.405 |
| (8, 16, 1024, 64)   | head_bias     | torch.bfloat16 |          133.822 |             136.449 |          485.608 |            1189.198 |         0.981 |         0.408 |
| (8, 16, 1024, 128)  | noop          | torch.bfloat16 |          235.470 |             234.765 |          886.094 |            2662.944 |         1.003 |         0.333 |
| (8, 16, 1024, 128)  | causal_mask   | torch.bfloat16 |          236.305 |             241.382 |          886.293 |            2646.984 |         0.979 |         0.335 |
| (8, 16, 1024, 128)  | relative_bias | torch.bfloat16 |          236.414 |             233.980 |          885.250 |            2642.178 |         1.010 |         0.335 |
| (8, 16, 1024, 128)  | head_bias     | torch.bfloat16 |          237.176 |             239.040 |          885.754 |            2665.242 |         0.992 |         0.332 |
| (8, 16, 1024, 256)  | noop          | torch.bfloat16 |          504.445 |             517.855 |         1978.956 |            9592.906 |         0.974 |         0.206 |
| (8, 16, 1024, 256)  | causal_mask   | torch.bfloat16 |          502.428 |             536.002 |         1978.611 |           10607.342 |         0.937 |         0.187 |
| (8, 16, 1024, 256)  | relative_bias | torch.bfloat16 |          503.396 |             523.960 |         1977.993 |            9539.284 |         0.961 |         0.207 |
| (8, 16, 1024, 256)  | head_bias     | torch.bfloat16 |          503.818 |             536.014 |         1980.131 |            9576.262 |         0.940 |         0.207 |
| (8, 16, 4096, 64)   | noop          | torch.bfloat16 |         1970.139 |            1674.930 |         5750.940 |           16724.134 |         1.176 |         0.344 |
| (8, 16, 4096, 64)   | causal_mask   | torch.bfloat16 |         1959.036 |            1775.056 |         5780.512 |           17390.350 |         1.104 |         0.332 |
| (8, 16, 4096, 64)   | relative_bias | torch.bfloat16 |         1947.198 |            1773.869 |         5780.643 |           16779.699 |         1.098 |         0.345 |
| (8, 16, 4096, 64)   | head_bias     | torch.bfloat16 |         1963.935 |            1829.502 |         5780.018 |           16703.259 |         1.073 |         0.346 |
| (8, 16, 4096, 128)  | noop          | torch.bfloat16 |         3582.711 |            3362.623 |        10436.069 |           36415.565 |         1.065 |         0.287 |
| (8, 16, 4096, 128)  | causal_mask   | torch.bfloat16 |         3581.504 |            3499.472 |        10346.869 |           36164.959 |         1.023 |         0.286 |
| (8, 16, 4096, 128)  | relative_bias | torch.bfloat16 |         3589.779 |            3337.849 |        10529.621 |           36261.696 |         1.075 |         0.290 |
| (8, 16, 4096, 128)  | head_bias     | torch.bfloat16 |         3602.265 |            3436.444 |        10458.660 |           36507.790 |         1.048 |         0.286 |
| (8, 16, 4096, 256)  | noop          | torch.bfloat16 |         7695.923 |            7126.275 |        24643.009 |          140949.081 |         1.080 |         0.175 |
| (8, 16, 4096, 256)  | causal_mask   | torch.bfloat16 |         7679.939 |            7186.252 |        24538.105 |          157156.067 |         1.069 |         0.156 |
| (8, 16, 4096, 256)  | relative_bias | torch.bfloat16 |         7681.374 |            6994.832 |        24549.713 |          140077.179 |         1.098 |         0.175 |
| (8, 16, 4096, 256)  | head_bias     | torch.bfloat16 |         7679.822 |            7212.278 |        24627.823 |          140675.003 |         1.065 |         0.175 |
| (16, 16, 512, 64)   | noop          | torch.bfloat16 |           80.126 |              78.291 |          333.719 |             541.165 |         1.023 |         0.617 |
| (16, 16, 512, 64)   | causal_mask   | torch.bfloat16 |           80.065 |              81.696 |          333.779 |             551.113 |         0.980 |         0.606 |
| (16, 16, 512, 64)   | relative_bias | torch.bfloat16 |           80.138 |              86.715 |          333.364 |             542.118 |         0.924 |         0.615 |
| (16, 16, 512, 64)   | head_bias     | torch.bfloat16 |           80.415 |              85.204 |          333.294 |             536.840 |         0.944 |         0.621 |
| (16, 16, 512, 128)  | noop          | torch.bfloat16 |          134.964 |             138.025 |          607.093 |            1333.102 |         0.978 |         0.455 |
| (16, 16, 512, 128)  | causal_mask   | torch.bfloat16 |          134.192 |             141.523 |          606.269 |            1424.318 |         0.948 |         0.426 |
| (16, 16, 512, 128)  | relative_bias | torch.bfloat16 |          135.711 |             138.639 |          606.283 |            1327.974 |         0.979 |         0.457 |
| (16, 16, 512, 128)  | head_bias     | torch.bfloat16 |          135.552 |             140.555 |          607.107 |            1347.370 |         0.964 |         0.451 |
| (16, 16, 512, 256)  | noop          | torch.bfloat16 |          275.113 |             315.144 |         1301.583 |            5268.153 |         0.873 |         0.247 |
| (16, 16, 512, 256)  | causal_mask   | torch.bfloat16 |          274.867 |             328.106 |         1302.513 |            5770.594 |         0.838 |         0.226 |
| (16, 16, 512, 256)  | relative_bias | torch.bfloat16 |          276.052 |             321.770 |         1302.904 |            5241.920 |         0.858 |         0.249 |
| (16, 16, 512, 256)  | head_bias     | torch.bfloat16 |          271.409 |             328.839 |         1302.142 |            5266.037 |         0.825 |         0.247 |
| (16, 16, 1024, 64)  | noop          | torch.bfloat16 |          260.489 |             237.463 |          955.884 |            1817.558 |         1.097 |         0.526 |
| (16, 16, 1024, 64)  | causal_mask   | torch.bfloat16 |          262.378 |             254.350 |          955.280 |            1843.807 |         1.032 |         0.518 |
| (16, 16, 1024, 64)  | relative_bias | torch.bfloat16 |          261.338 |             268.253 |          956.038 |            1820.036 |         0.974 |         0.525 |
| (16, 16, 1024, 64)  | head_bias     | torch.bfloat16 |          262.153 |             264.156 |          956.023 |            1810.076 |         0.992 |         0.528 |
| (16, 16, 1024, 128) | noop          | torch.bfloat16 |          476.475 |             461.413 |         1760.578 |            4306.521 |         1.033 |         0.409 |
| (16, 16, 1024, 128) | causal_mask   | torch.bfloat16 |          473.794 |             479.178 |         1761.277 |            4619.439 |         0.989 |         0.381 |
| (16, 16, 1024, 128) | relative_bias | torch.bfloat16 |          473.839 |             463.282 |         1758.692 |            4290.562 |         1.023 |         0.410 |
| (16, 16, 1024, 128) | head_bias     | torch.bfloat16 |          472.979 |             472.896 |         1763.086 |            4367.931 |         1.000 |         0.404 |
| (16, 16, 1024, 256) | noop          | torch.bfloat16 |         1014.184 |            1026.764 |         3922.997 |           19104.147 |         0.988 |         0.205 |
| (16, 16, 1024, 256) | causal_mask   | torch.bfloat16 |         1013.217 |            1039.046 |         3928.382 |           21086.281 |         0.975 |         0.186 |
| (16, 16, 1024, 256) | relative_bias | torch.bfloat16 |         1008.519 |            1015.278 |         3922.133 |           18980.652 |         0.993 |         0.207 |
| (16, 16, 1024, 256) | head_bias     | torch.bfloat16 |         1011.360 |            1047.542 |         3931.245 |           19069.172 |         0.965 |         0.206 |
| (16, 16, 4096, 64)  | noop          | torch.bfloat16 |         3929.850 |            3325.667 |        11411.704 |           23344.280 |         1.182 |         0.489 |
| (16, 16, 4096, 64)  | causal_mask   | torch.bfloat16 |         3885.262 |            3581.544 |        11390.515 |           23725.639 |         1.085 |         0.480 |
| (16, 16, 4096, 64)  | relative_bias | torch.bfloat16 |         3865.737 |            3537.308 |        11489.901 |           23406.330 |         1.093 |         0.491 |
| (16, 16, 4096, 64)  | head_bias     | torch.bfloat16 |         3880.530 |            3665.249 |        11484.411 |           23299.496 |         1.059 |         0.493 |
| (16, 16, 4096, 128) | noop          | torch.bfloat16 |         7030.306 |            6745.715 |        20621.264 |           57464.096 |         1.042 |         0.359 |
| (16, 16, 4096, 128) | causal_mask   | torch.bfloat16 |         7095.414 |            7034.385 |        20410.656 |           61660.511 |         1.009 |         0.331 |
| (16, 16, 4096, 128) | relative_bias | torch.bfloat16 |         7084.779 |            6686.497 |        20315.161 |           57243.969 |         1.060 |         0.355 |
| (16, 16, 4096, 128) | head_bias     | torch.bfloat16 |         7075.367 |            6863.305 |        20494.385 |           58481.953 |         1.031 |         0.350 |
| (16, 16, 4096, 256) | noop          | torch.bfloat16 |        15612.741 |           14297.482 |        55306.847 |          281161.865 |         1.092 |         0.197 |
| (16, 16, 4096, 256) | causal_mask   | torch.bfloat16 |        15326.592 |           14263.878 |        55227.806 |          313063.232 |         1.075 |         0.176 |
| (16, 16, 4096, 256) | relative_bias | torch.bfloat16 |        15297.963 |           14007.379 |        54558.029 |          279529.175 |         1.092 |         0.195 |
| (16, 16, 4096, 256) | head_bias     | torch.bfloat16 |        15216.160 |           14276.027 |        55081.581 |          280996.826 |         1.066 |         0.196 |

</details>

Pull Request resolved: pytorch#125515
Approved by: https://github.com/Chillee
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants